Need GPU warning

Running this mri-nufft example requires a GPU, so it is currently not possible on Binder. We recommend running this notebook on Google Colab by clicking the link below. Please also set the Colab runtime to use a GPU and install the libraries below before running.

Open In Colab
    

Simple U-Net model

This model is a simplified version of the U-Net architecture, which is widely used for image segmentation tasks. It is implemented in the open-source fastMRI package [fastmri]_.

The U-Net model consists of an encoder (downsampling path) and a decoder (upsampling path) with skip connections between corresponding layers in the encoder and decoder. These skip connections help in retaining spatial information that is lost during the downsampling process.
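A skip connection of this kind can be sketched in a few lines of PyTorch (the tensor sizes below are illustrative, not those of the fastMRI implementation): the encoder feature map saved during downsampling is concatenated, along the channel axis, with the upsampled decoder feature map.

```python
import torch

# Illustrative feature maps: one saved on the way down, one coming back up.
enc_feat = torch.randn(1, 32, 64, 64)   # encoder output at some resolution
dec_feat = torch.randn(1, 32, 32, 32)   # decoder output one level below

# Upsample the decoder features back to the encoder resolution, then
# concatenate along the channel axis -- this is the skip connection.
up = torch.nn.ConvTranspose2d(32, 32, kernel_size=2, stride=2)
merged = torch.cat([enc_feat, up(dec_feat)], dim=1)  # shape (1, 64, 64, 64)
```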

The primary purpose of this model is to perform image reconstruction tasks, specifically for MRI images. It takes an input MRI image and reconstructs it to improve the image quality or to recover missing parts of the image.

This implementation of the UNet model was pulled from the FastMRI Facebook repository, which is a collaborative research project aimed at advancing the field of medical imaging using machine learning techniques.

\begin{align}\mathbf{\hat{\theta}} = \arg \min_{\mathbf{\theta}} \| \mathcal{U}_{\mathbf{\theta}}(\mathbf{y}) - \mathbf{x} \|_2^2\end{align}

where $\mathbf{x}$ is the ground truth image, $\mathbf{y}$ is the input data (e.g., k-space measurements), and $\mathcal{U}_{\mathbf{\theta}}$ is the U-Net model parameterized by $\theta$; the reconstructed image is then $\mathbf{\hat{x}} = \mathcal{U}_{\mathbf{\hat{\theta}}}(\mathbf{y})$.
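Training therefore amounts to ordinary gradient descent on the network parameters $\theta$, with the measurements and ground truth held fixed. A minimal CPU-only sketch with a stand-in linear "reconstructor" (the names `toy_net`, `y` and `x` are illustrative, not part of this example):

```python
import torch

# Stand-in for the U-Net: any differentiable module mapping y -> x_hat.
toy_net = torch.nn.Linear(16, 16)
y = torch.randn(4, 16)  # "measurements"
x = torch.randn(4, 16)  # ground-truth images

optimizer = torch.optim.SGD(toy_net.parameters(), lr=1e-2)
loss = torch.nn.functional.mse_loss(toy_net(y), x)  # || U_theta(y) - x ||_2^2
optimizer.zero_grad()
loss.backward()   # gradients flow to theta only
optimizer.step()  # one descent step on theta
```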

Warning

We train on a single image here. In practice, this should be done on a database like fastMRI [fastmri]_.

In [ ]:
# Install libraries
!pip install mri-nufft[gpunufft] scikit-image fastmri
!pip install brainweb-dl  # Required for data

Imports

In [121]:
import os
from pathlib import Path
import shutil
import brainweb_dl as bwdl
import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm
import time
import joblib
from PIL import Image
import tempfile as tmp

from fastmri.models import Unet
from mrinufft import get_operator
from mrinufft.trajectories import initialize_2D_cones, initialize_2D_radial, initialize_2D_spiral

Set up a simple class for the U-Net model

In [122]:
class Model(torch.nn.Module):
    """Model for MRI reconstruction using a U-Net."""

    def __init__(self, initial_trajectory):
        super().__init__()
        self.operator = get_operator("gpunufft", wrt_data=True)(
            initial_trajectory,
            shape=(256, 256),
            density=True,
            squeeze_dims=False,
        )
        self.unet = Unet(in_chans=1, out_chans=1, chans=32, num_pool_layers=4)

    def forward(self, kspace):
        """Forward pass of the model."""
        image = self.operator.adj_op(kspace)
        recon = self.unet(image.float()).abs()
        recon /= torch.mean(recon)
        return recon

Utility function to plot the state of the model

In [123]:
def plot_state(axs, mri_2D, traj, recon, loss=None, save_name=None):
    """Image plotting function.

    Plot the original MRI image, the trajectory, the reconstructed image,
    and the loss curve (if provided). Saves the plot if a filename is provided.

    Parameters
    ----------
    axs : numpy.ndarray
        Array of matplotlib axes to plot on.
    mri_2D : torch.Tensor
        Original MRI image.
    traj : numpy.ndarray
        K-space sampling trajectory.
    recon : torch.Tensor
        Reconstructed image after training.
    loss : list, optional
        List of loss values to plot. Defaults to None.
    save_name : str, optional
        Filename to save the plot. Defaults to None.
    """
    axs = axs.flatten()
    axs[0].imshow(np.abs(mri_2D[0]), cmap="gray")
    axs[0].axis("off")
    axs[0].set_title("MR Image")
    axs[1].scatter(*traj.T, s=0.5)
    axs[1].set_title("Trajectory")
    axs[2].imshow(np.abs(recon[0][0].detach().cpu().numpy()), cmap="gray")
    axs[2].axis("off")
    axs[2].set_title("Reconstruction")
    if loss is not None:
        axs[3].plot(loss)
        axs[3].grid("on")
        axs[3].set_title("Loss")
    if save_name is not None:
        plt.savefig(save_name, bbox_inches="tight")
        plt.close()
    else:
        plt.show()

Set up inputs (model, trajectory and image)

Q1  

  • Change the trajectory initialization (2D cones, 2D spiral arches, 2D radial spokes). What is the impact on image quality?

$\leadsto$   After 100 epochs, the 2D spiral initialization achieves the best reconstruction quality.

2D cones


2D spiral arches


2D radial spokes


Q2  

  • Change the number of training epochs ($10$ vs $100$). What is the impact on image quality?

$\leadsto$   With only 10 training epochs, the reconstructed image quality is worse than with 100 epochs.

2D radial spokes 10 epochs


2D cones
In [124]:
init_traj = initialize_2D_cones(32, 256).reshape(-1, 2).astype(np.float32)
model = Model(init_traj)
model.eval()
Out[124]:
Model(
  (operator): MRINufftAutoGrad()
  (unet): Unet(
    (down_sample_layers): ModuleList(
      (0): ConvBlock(
        (layers): Sequential(
          (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
          (3): Dropout2d(p=0.0, inplace=False)
          (4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (6): LeakyReLU(negative_slope=0.2, inplace=True)
          (7): Dropout2d(p=0.0, inplace=False)
        )
      )
      (1): ConvBlock(
        (layers): Sequential(
          (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
          (3): Dropout2d(p=0.0, inplace=False)
          (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (6): LeakyReLU(negative_slope=0.2, inplace=True)
          (7): Dropout2d(p=0.0, inplace=False)
        )
      )
      (2): ConvBlock(
        (layers): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
          (3): Dropout2d(p=0.0, inplace=False)
          (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (6): LeakyReLU(negative_slope=0.2, inplace=True)
          (7): Dropout2d(p=0.0, inplace=False)
        )
      )
      (3): ConvBlock(
        (layers): Sequential(
          (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
          (3): Dropout2d(p=0.0, inplace=False)
          (4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (6): LeakyReLU(negative_slope=0.2, inplace=True)
          (7): Dropout2d(p=0.0, inplace=False)
        )
      )
    )
    (conv): ConvBlock(
      (layers): Sequential(
        (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
        (3): Dropout2d(p=0.0, inplace=False)
        (4): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (5): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (6): LeakyReLU(negative_slope=0.2, inplace=True)
        (7): Dropout2d(p=0.0, inplace=False)
      )
    )
    (up_conv): ModuleList(
      (0): ConvBlock(
        (layers): Sequential(
          (0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
          (3): Dropout2d(p=0.0, inplace=False)
          (4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (6): LeakyReLU(negative_slope=0.2, inplace=True)
          (7): Dropout2d(p=0.0, inplace=False)
        )
      )
      (1): ConvBlock(
        (layers): Sequential(
          (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
          (3): Dropout2d(p=0.0, inplace=False)
          (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (6): LeakyReLU(negative_slope=0.2, inplace=True)
          (7): Dropout2d(p=0.0, inplace=False)
        )
      )
      (2): ConvBlock(
        (layers): Sequential(
          (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
          (3): Dropout2d(p=0.0, inplace=False)
          (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (6): LeakyReLU(negative_slope=0.2, inplace=True)
          (7): Dropout2d(p=0.0, inplace=False)
        )
      )
      (3): Sequential(
        (0): ConvBlock(
          (layers): Sequential(
            (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
            (2): LeakyReLU(negative_slope=0.2, inplace=True)
            (3): Dropout2d(p=0.0, inplace=False)
            (4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (5): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
            (6): LeakyReLU(negative_slope=0.2, inplace=True)
            (7): Dropout2d(p=0.0, inplace=False)
          )
        )
        (1): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (up_transpose_conv): ModuleList(
      (0): TransposeConvBlock(
        (layers): Sequential(
          (0): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2), bias=False)
          (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
        )
      )
      (1): TransposeConvBlock(
        (layers): Sequential(
          (0): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2), bias=False)
          (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
        )
      )
      (2): TransposeConvBlock(
        (layers): Sequential(
          (0): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2), bias=False)
          (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
        )
      )
      (3): TransposeConvBlock(
        (layers): Sequential(
          (0): ConvTranspose2d(64, 32, kernel_size=(2, 2), stride=(2, 2), bias=False)
          (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
        )
      )
    )
  )
)

Get the image on which we will train our U-Net model

In [125]:
mri_2D = torch.Tensor(np.flipud(bwdl.get_mri(4, "T1")[80, ...]).astype(np.complex64))[
    None
]
mri_2D = mri_2D / torch.mean(mri_2D)
kspace_mri_2D = model.operator.op(mri_2D)

# Before training, here is the simple reconstruction we have using a
# density compensated adjoint.
dc_adjoint = model.operator.adj_op(kspace_mri_2D)
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
plot_state(axs, mri_2D, init_traj, dc_adjoint)

Start training loop

In [127]:
num_epochs = 100
optimizer = torch.optim.RAdam(model.parameters(), lr=1e-3)
losses = []  # Store the loss values and create an animation
image_files = []  # Store the images to create a gif
model.train()

with tqdm(range(num_epochs), unit="steps") as tqdms:
    for i in tqdms:
        out = model(kspace_mri_2D)  # Forward pass

        loss = torch.nn.functional.l1_loss(out, mri_2D[None])  # Compute loss
        tqdms.set_postfix({"loss": loss.item()})  # Update progress bar
        losses.append(loss.item())  # Store loss value

        optimizer.zero_grad()  # Zero gradients
        loss.backward()  # Backward pass
        optimizer.step()  # Update weights

        # Generate images for gif
        filename = f"{tmp.NamedTemporaryFile().name}.png"
        fig, axs = plt.subplots(2, 2, figsize=(10, 10))
        plot_state(
            axs,
            mri_2D,
            init_traj,
            out,
            losses,
            save_name=filename,
        )
        image_files.append(filename)


# Make a GIF of all images.
imgs = [Image.open(img) for img in image_files]
imgs[0].save(
    "mrinufft_learn_unet.gif",
    save_all=True,
    append_images=imgs[1:],
    optimize=False,
    duration=2,
    loop=0,
)

# sphinx_gallery_thumbnail_path = 'generated/autoexamples/GPU/images/mrinufft_learn_unet.gif'
100%|██████████| 100/100 [01:12<00:00,  1.37steps/s, loss=0.0608]


mrinufft_learn_unet_cones.gif

Reconstruction from partially trained U-Net model

In [128]:
model.eval()
new_recon = model(kspace_mri_2D)
fig, axs = plt.subplots(2, 2, figsize=(10, 10))
plot_state(axs, mri_2D, init_traj, new_recon, losses)
plt.show()
In [129]:
print(mri_2D.shape)             # MR image
print(new_recon.shape)          # Reconstruction
torch.Size([1, 256, 256])
torch.Size([1, 1, 256, 256])
In [130]:
print(mri_2D.squeeze().numpy().shape) 
print(new_recon.squeeze().detach().numpy().shape)  
(256, 256)
(256, 256)
In [131]:
from modopt.math.metrics import ssim  # structural similarity (SSIM) metric
In [132]:
mri_2D_cones = mri_2D.squeeze().numpy()
new_recon_2D_cones = new_recon.squeeze().detach().numpy()
SSIM_cones = ssim(new_recon_2D_cones, mri_2D_cones)
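The metric can also be cross-checked with scikit-image (installed above); note that `structural_similarity` requires an explicit `data_range` for float images. A small sketch on stand-in arrays (in the notebook, these would be `mri_2D_cones` and `new_recon_2D_cones`):

```python
import numpy as np
from skimage.metrics import structural_similarity

rng = np.random.default_rng(0)
ref = rng.random((256, 256)).astype(np.float32)                 # stand-in reference
test = ref + 0.05 * rng.random((256, 256)).astype(np.float32)   # noisy copy

score = structural_similarity(test, ref, data_range=float(ref.max() - ref.min()))
```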
In [133]:
fig, axes = plt.subplots(1, 2, figsize=(8, 4))

# Left
axes[0].imshow(np.abs(new_recon_2D_cones), cmap='gray')
axes[0].set_title("Reconstruction 2D cones : 100 epochs")
axes[0].axis('off')

# Right
axes[1].imshow(np.abs(mri_2D_cones), cmap='gray')
axes[1].set_title(f'mri_2D_cones : SSIM = {SSIM_cones:.3f}')
axes[1].axis('off')

plt.tight_layout()
plt.show()
2D spiral arches
In [134]:
init_traj = initialize_2D_spiral(32, 256).reshape(-1, 2).astype(np.float32)
model = Model(init_traj)
model.eval()
Out[134]:
Model(
  (operator): MRINufftAutoGrad()
  (unet): Unet(
    (down_sample_layers): ModuleList(
      (0): ConvBlock(
        (layers): Sequential(
          (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
          (3): Dropout2d(p=0.0, inplace=False)
          (4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (6): LeakyReLU(negative_slope=0.2, inplace=True)
          (7): Dropout2d(p=0.0, inplace=False)
        )
      )
      (1): ConvBlock(
        (layers): Sequential(
          (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
          (3): Dropout2d(p=0.0, inplace=False)
          (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (6): LeakyReLU(negative_slope=0.2, inplace=True)
          (7): Dropout2d(p=0.0, inplace=False)
        )
      )
      (2): ConvBlock(
        (layers): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
          (3): Dropout2d(p=0.0, inplace=False)
          (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (6): LeakyReLU(negative_slope=0.2, inplace=True)
          (7): Dropout2d(p=0.0, inplace=False)
        )
      )
      (3): ConvBlock(
        (layers): Sequential(
          (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
          (3): Dropout2d(p=0.0, inplace=False)
          (4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (6): LeakyReLU(negative_slope=0.2, inplace=True)
          (7): Dropout2d(p=0.0, inplace=False)
        )
      )
    )
    (conv): ConvBlock(
      (layers): Sequential(
        (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
        (3): Dropout2d(p=0.0, inplace=False)
        (4): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (5): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (6): LeakyReLU(negative_slope=0.2, inplace=True)
        (7): Dropout2d(p=0.0, inplace=False)
      )
    )
    (up_conv): ModuleList(
      (0): ConvBlock(
        (layers): Sequential(
          (0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
          (3): Dropout2d(p=0.0, inplace=False)
          (4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (6): LeakyReLU(negative_slope=0.2, inplace=True)
          (7): Dropout2d(p=0.0, inplace=False)
        )
      )
      (1): ConvBlock(
        (layers): Sequential(
          (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
          (3): Dropout2d(p=0.0, inplace=False)
          (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (6): LeakyReLU(negative_slope=0.2, inplace=True)
          (7): Dropout2d(p=0.0, inplace=False)
        )
      )
      (2): ConvBlock(
        (layers): Sequential(
          (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
          (3): Dropout2d(p=0.0, inplace=False)
          (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (6): LeakyReLU(negative_slope=0.2, inplace=True)
          (7): Dropout2d(p=0.0, inplace=False)
        )
      )
      (3): Sequential(
        (0): ConvBlock(
          (layers): Sequential(
            (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
            (2): LeakyReLU(negative_slope=0.2, inplace=True)
            (3): Dropout2d(p=0.0, inplace=False)
            (4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (5): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
            (6): LeakyReLU(negative_slope=0.2, inplace=True)
            (7): Dropout2d(p=0.0, inplace=False)
          )
        )
        (1): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (up_transpose_conv): ModuleList(
      (0): TransposeConvBlock(
        (layers): Sequential(
          (0): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2), bias=False)
          (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
        )
      )
      (1): TransposeConvBlock(
        (layers): Sequential(
          (0): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2), bias=False)
          (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
        )
      )
      (2): TransposeConvBlock(
        (layers): Sequential(
          (0): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2), bias=False)
          (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
        )
      )
      (3): TransposeConvBlock(
        (layers): Sequential(
          (0): ConvTranspose2d(64, 32, kernel_size=(2, 2), stride=(2, 2), bias=False)
          (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
        )
      )
    )
  )
)

Get the image on which we will train our U-Net model

In [135]:
mri_2D = torch.Tensor(np.flipud(bwdl.get_mri(4, "T1")[80, ...]).astype(np.complex64))[
    None
]
mri_2D = mri_2D / torch.mean(mri_2D)
kspace_mri_2D = model.operator.op(mri_2D)

# Before training, here is the simple reconstruction we have using a
# density compensated adjoint.
dc_adjoint = model.operator.adj_op(kspace_mri_2D)
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
plot_state(axs, mri_2D, init_traj, dc_adjoint)

Start training loop

In [136]:
num_epochs = 100
optimizer = torch.optim.RAdam(model.parameters(), lr=1e-3)
losses = []  # Store the loss values and create an animation
image_files = []  # Store the images to create a gif
model.train()

with tqdm(range(num_epochs), unit="steps") as tqdms:
    for i in tqdms:
        out = model(kspace_mri_2D)  # Forward pass

        loss = torch.nn.functional.l1_loss(out, mri_2D[None])  # Compute loss
        tqdms.set_postfix({"loss": loss.item()})  # Update progress bar
        losses.append(loss.item())  # Store loss value

        optimizer.zero_grad()  # Zero gradients
        loss.backward()  # Backward pass
        optimizer.step()  # Update weights

        # Generate images for gif
        filename = f"{tmp.NamedTemporaryFile().name}.png"
        fig, axs = plt.subplots(2, 2, figsize=(10, 10))
        plot_state(
            axs,
            mri_2D,
            init_traj,
            out,
            losses,
            save_name=filename,
        )
        image_files.append(filename)


# Make a GIF of all images.
imgs = [Image.open(img) for img in image_files]
imgs[0].save(
    "mrinufft_learn_unet.gif",
    save_all=True,
    append_images=imgs[1:],
    optimize=False,
    duration=2,
    loop=0,
)

# sphinx_gallery_thumbnail_path = 'generated/autoexamples/GPU/images/mrinufft_learn_unet.gif'
100%|██████████| 100/100 [01:14<00:00,  1.35steps/s, loss=0.0454]

mrinufft_learn_unet_spiral.gif

Reconstruction from partially trained U-Net model

In [137]:
model.eval()
new_recon = model(kspace_mri_2D)
fig, axs = plt.subplots(2, 2, figsize=(10, 10))
plot_state(axs, mri_2D, init_traj, new_recon, losses)
plt.show()
In [138]:
mri_2D_spiral = mri_2D.squeeze().numpy()
new_recon_2D_spiral = new_recon.squeeze().detach().numpy()
SSIM_spiral = ssim(new_recon_2D_spiral, mri_2D_spiral)
In [139]:
fig, axes = plt.subplots(1, 2, figsize=(8, 4))

# Left
axes[0].imshow(np.abs(new_recon_2D_spiral), cmap='gray')
axes[0].set_title("Reconstruction 2D spiral : 100 epochs")
axes[0].axis('off')

# Right
axes[1].imshow(np.abs(mri_2D_spiral), cmap='gray')
axes[1].set_title(f'mri_2D_spiral : SSIM = {SSIM_spiral:.3f}')
axes[1].axis('off')

plt.tight_layout()
plt.show()
2D radial spokes
In [140]:
init_traj = initialize_2D_radial(32, 256).reshape(-1, 2).astype(np.float32)
model = Model(init_traj)
model.eval()
Out[140]:
Model(
  (operator): MRINufftAutoGrad()
  (unet): Unet(
    (down_sample_layers): ModuleList(
      (0): ConvBlock(
        (layers): Sequential(
          (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
          (3): Dropout2d(p=0.0, inplace=False)
          (4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (6): LeakyReLU(negative_slope=0.2, inplace=True)
          (7): Dropout2d(p=0.0, inplace=False)
        )
      )
      (1): ConvBlock(
        (layers): Sequential(
          (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
          (3): Dropout2d(p=0.0, inplace=False)
          (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (6): LeakyReLU(negative_slope=0.2, inplace=True)
          (7): Dropout2d(p=0.0, inplace=False)
        )
      )
      (2): ConvBlock(
        (layers): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
          (3): Dropout2d(p=0.0, inplace=False)
          (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (6): LeakyReLU(negative_slope=0.2, inplace=True)
          (7): Dropout2d(p=0.0, inplace=False)
        )
      )
      (3): ConvBlock(
        (layers): Sequential(
          (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
          (3): Dropout2d(p=0.0, inplace=False)
          (4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (6): LeakyReLU(negative_slope=0.2, inplace=True)
          (7): Dropout2d(p=0.0, inplace=False)
        )
      )
    )
    (conv): ConvBlock(
      (layers): Sequential(
        (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
        (3): Dropout2d(p=0.0, inplace=False)
        (4): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (5): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (6): LeakyReLU(negative_slope=0.2, inplace=True)
        (7): Dropout2d(p=0.0, inplace=False)
      )
    )
    (up_conv): ModuleList(
      (0): ConvBlock(
        (layers): Sequential(
          (0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
          (3): Dropout2d(p=0.0, inplace=False)
          (4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (6): LeakyReLU(negative_slope=0.2, inplace=True)
          (7): Dropout2d(p=0.0, inplace=False)
        )
      )
      (1): ConvBlock(
        (layers): Sequential(
          (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
          (3): Dropout2d(p=0.0, inplace=False)
          (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (6): LeakyReLU(negative_slope=0.2, inplace=True)
          (7): Dropout2d(p=0.0, inplace=False)
        )
      )
      (2): ConvBlock(
        (layers): Sequential(
          (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
          (3): Dropout2d(p=0.0, inplace=False)
          (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (6): LeakyReLU(negative_slope=0.2, inplace=True)
          (7): Dropout2d(p=0.0, inplace=False)
        )
      )
      (3): Sequential(
        (0): ConvBlock(
          (layers): Sequential(
            (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
            (2): LeakyReLU(negative_slope=0.2, inplace=True)
            (3): Dropout2d(p=0.0, inplace=False)
            (4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (5): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
            (6): LeakyReLU(negative_slope=0.2, inplace=True)
            (7): Dropout2d(p=0.0, inplace=False)
          )
        )
        (1): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (up_transpose_conv): ModuleList(
      (0): TransposeConvBlock(
        (layers): Sequential(
          (0): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2), bias=False)
          (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
        )
      )
      (1): TransposeConvBlock(
        (layers): Sequential(
          (0): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2), bias=False)
          (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
        )
      )
      (2): TransposeConvBlock(
        (layers): Sequential(
          (0): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2), bias=False)
          (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
        )
      )
      (3): TransposeConvBlock(
        (layers): Sequential(
          (0): ConvTranspose2d(64, 32, kernel_size=(2, 2), stride=(2, 2), bias=False)
          (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
        )
      )
    )
  )
)

Get the image on which we will train our U-Net model

In [141]:
mri_2D = torch.Tensor(np.flipud(bwdl.get_mri(4, "T1")[80, ...]).astype(np.complex64))[
    None
]
mri_2D = mri_2D / torch.mean(mri_2D)
kspace_mri_2D = model.operator.op(mri_2D)

# Before training, here is the simple reconstruction we have using a
# density compensated adjoint.
dc_adjoint = model.operator.adj_op(kspace_mri_2D)
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
plot_state(axs, mri_2D, init_traj, dc_adjoint)
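The density-compensated adjoint above applies the conjugate transpose of the forward operator to weighted k-space data, $\mathbf{\hat{x}} = \mathcal{A}^H D \mathbf{y}$. A minimal Cartesian analogue of this idea (toy `forward`/`adjoint` stand-ins built from orthonormal FFTs, not mri-nufft's API): with uniform sampling the density weights are all 1 and the adjoint recovers the image exactly.

```python
import numpy as np

# Toy Cartesian analogue of the forward/adjoint pair: with orthonormal
# FFTs and fully uniform sampling, the density weights D are all 1 and
# the adjoint inverts the forward model exactly.
forward = lambda x: np.fft.fft2(x, norm="ortho")              # image -> k-space
adjoint = lambda y, w=1.0: np.fft.ifft2(w * y, norm="ortho")  # weighted k-space -> image

rng = np.random.default_rng(0)
demo_image = rng.random((32, 32))

kspace = forward(demo_image)
recon = adjoint(kspace)  # density weights are uniform here

print(np.allclose(recon, demo_image))  # True: exact recovery in the Cartesian case
```

With a non-Cartesian trajectory the sampling density is uneven, which is why the real operator needs non-trivial compensation weights before the adjoint gives a usable image.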

Start training loop

In [142]:
num_epochs = 100
optimizer = torch.optim.RAdam(model.parameters(), lr=1e-3)
losses = []  # Store the loss values and create an animation
image_files = []  # Store the images to create a gif
model.train()

with tqdm(range(num_epochs), unit="steps") as tqdms:
    for i in tqdms:
        out = model(kspace_mri_2D)  # Forward pass

        loss = torch.nn.functional.l1_loss(out, mri_2D[None])  # Compute loss
        tqdms.set_postfix({"loss": loss.item()})  # Update progress bar
        losses.append(loss.item())  # Store loss value

        optimizer.zero_grad()  # Zero gradients
        loss.backward()  # Backward pass
        optimizer.step()  # Update weights

        # Generate a frame for the GIF
        filename = f"{tmp.NamedTemporaryFile().name}.png"
        fig, axs = plt.subplots(2, 2, figsize=(10, 10))
        plot_state(
            axs,
            mri_2D,
            init_traj,
            out,
            losses,
            save_name=filename,
        )
        plt.close(fig)  # Free the figure; the frame is already saved to disk
        image_files.append(filename)


# Make a GIF of all images.
imgs = [Image.open(img) for img in image_files]
imgs[0].save(
    "mrinufft_learn_unet.gif",
    save_all=True,
    append_images=imgs[1:],
    optimize=False,
    duration=2,
    loop=0,
)
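The per-step values collected in `losses` can be noisy from one iteration to the next; a running mean makes the trend easier to read before plotting. A small sketch with plain NumPy (the `window` size here is an arbitrary choice, not part of the example above):

```python
import numpy as np

def smooth(losses, window=5):
    """Running mean of a loss curve; returns len(losses) - window + 1 points."""
    kernel = np.ones(window) / window
    return np.convolve(np.asarray(losses, dtype=float), kernel, mode="valid")

# A short synthetic loss curve, smoothed with a 3-step window.
demo_losses = [1.0, 0.8, 0.9, 0.6, 0.5, 0.55, 0.4]
print(smooth(demo_losses, window=3))
```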

# sphinx_gallery_thumbnail_path = 'generated/autoexamples/GPU/images/mrinufft_learn_unet.gif'
100%|██████████| 100/100 [01:14<00:00,  1.35steps/s, loss=0.0583]

mrinufft_learn_unet_radial.gif

Reconstruction from partially trained U-Net model

In [143]:
model.eval()
new_recon = model(kspace_mri_2D)
fig, axs = plt.subplots(2, 2, figsize=(10, 10))
plot_state(axs, mri_2D, init_traj, new_recon, losses)
plt.show()
In [144]:
mri_2D_radial = mri_2D.squeeze().numpy()
new_recon_2D_radial = new_recon.squeeze().detach().numpy()
# SSIM is defined on real-valued images, so compare magnitudes and pass
# the data range explicitly (required for float inputs).
SSIM_radial = ssim(
    np.abs(new_recon_2D_radial),
    np.abs(mri_2D_radial),
    data_range=np.abs(mri_2D_radial).ptp(),
)
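SSIM compares two images through their local means, variances, and covariance rather than pixel-wise error. A simplified, single-window variant (a hedged sketch, not the local windowed statistic that scikit-image's `ssim` computes) shows the formula at work:

```python
import numpy as np

def global_ssim(x, y, data_range=1.0):
    """Single-window SSIM over the whole image (simplified variant of the
    local, windowed SSIM used by scikit-image)."""
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2
    mx, my = x.mean(), y.mean()
    vx, vy = x.var(), y.var()
    cov = ((x - mx) * (y - my)).mean()
    return ((2 * mx * my + c1) * (2 * cov + c2)) / (
        (mx**2 + my**2 + c1) * (vx + vy + c2)
    )

rng = np.random.default_rng(0)
img = rng.random((64, 64))
noisy = np.clip(img + 0.1 * rng.standard_normal(img.shape), 0, 1)

print(global_ssim(img, img))    # identical images score exactly 1.0
print(global_ssim(img, noisy))  # added noise lowers the score
```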
In [145]:
fig, axes = plt.subplots(1, 2, figsize=(8, 4))

# Left: U-Net reconstruction
axes[0].imshow(np.abs(new_recon_2D_radial), cmap="gray")
axes[0].set_title("Reconstruction 2D radial: 100 epochs")
axes[0].axis("off")

# Right: ground truth
axes[1].imshow(np.abs(mri_2D_radial), cmap="gray")
axes[1].set_title(f"mri_2D_radial: SSIM = {SSIM_radial:.3f}")
axes[1].axis("off")

plt.tight_layout()
plt.show()
Training with 2D radial spokes for 10 epochs
In [146]:
init_traj = initialize_2D_radial(32, 256).reshape(-1, 2).astype(np.float32)
model = Model(init_traj)
model.eval()
Out[146]:
Model(
  (operator): MRINufftAutoGrad()
  (unet): Unet(
    (down_sample_layers): ModuleList(
      (0): ConvBlock(
        (layers): Sequential(
          (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
          (3): Dropout2d(p=0.0, inplace=False)
          (4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (6): LeakyReLU(negative_slope=0.2, inplace=True)
          (7): Dropout2d(p=0.0, inplace=False)
        )
      )
      (1): ConvBlock(
        (layers): Sequential(
          (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
          (3): Dropout2d(p=0.0, inplace=False)
          (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (6): LeakyReLU(negative_slope=0.2, inplace=True)
          (7): Dropout2d(p=0.0, inplace=False)
        )
      )
      (2): ConvBlock(
        (layers): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
          (3): Dropout2d(p=0.0, inplace=False)
          (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (6): LeakyReLU(negative_slope=0.2, inplace=True)
          (7): Dropout2d(p=0.0, inplace=False)
        )
      )
      (3): ConvBlock(
        (layers): Sequential(
          (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
          (3): Dropout2d(p=0.0, inplace=False)
          (4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (6): LeakyReLU(negative_slope=0.2, inplace=True)
          (7): Dropout2d(p=0.0, inplace=False)
        )
      )
    )
    (conv): ConvBlock(
      (layers): Sequential(
        (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
        (3): Dropout2d(p=0.0, inplace=False)
        (4): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (5): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (6): LeakyReLU(negative_slope=0.2, inplace=True)
        (7): Dropout2d(p=0.0, inplace=False)
      )
    )
    (up_conv): ModuleList(
      (0): ConvBlock(
        (layers): Sequential(
          (0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
          (3): Dropout2d(p=0.0, inplace=False)
          (4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (6): LeakyReLU(negative_slope=0.2, inplace=True)
          (7): Dropout2d(p=0.0, inplace=False)
        )
      )
      (1): ConvBlock(
        (layers): Sequential(
          (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
          (3): Dropout2d(p=0.0, inplace=False)
          (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (6): LeakyReLU(negative_slope=0.2, inplace=True)
          (7): Dropout2d(p=0.0, inplace=False)
        )
      )
      (2): ConvBlock(
        (layers): Sequential(
          (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
          (3): Dropout2d(p=0.0, inplace=False)
          (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (6): LeakyReLU(negative_slope=0.2, inplace=True)
          (7): Dropout2d(p=0.0, inplace=False)
        )
      )
      (3): Sequential(
        (0): ConvBlock(
          (layers): Sequential(
            (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
            (2): LeakyReLU(negative_slope=0.2, inplace=True)
            (3): Dropout2d(p=0.0, inplace=False)
            (4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (5): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
            (6): LeakyReLU(negative_slope=0.2, inplace=True)
            (7): Dropout2d(p=0.0, inplace=False)
          )
        )
        (1): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (up_transpose_conv): ModuleList(
      (0): TransposeConvBlock(
        (layers): Sequential(
          (0): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2), bias=False)
          (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
        )
      )
      (1): TransposeConvBlock(
        (layers): Sequential(
          (0): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2), bias=False)
          (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
        )
      )
      (2): TransposeConvBlock(
        (layers): Sequential(
          (0): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2), bias=False)
          (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
        )
      )
      (3): TransposeConvBlock(
        (layers): Sequential(
          (0): ConvTranspose2d(64, 32, kernel_size=(2, 2), stride=(2, 2), bias=False)
          (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): LeakyReLU(negative_slope=0.2, inplace=True)
        )
      )
    )
  )
)

Get the image on which we will train our U-Net model

In [148]:
mri_2D = torch.Tensor(np.flipud(bwdl.get_mri(4, "T1")[80, ...]).astype(np.complex64))[
    None
]
mri_2D = mri_2D / torch.mean(mri_2D)
kspace_mri_2D = model.operator.op(mri_2D)

# Before training, here is the simple reconstruction we have using a
# density compensated adjoint.
dc_adjoint = model.operator.adj_op(kspace_mri_2D)
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
plot_state(axs, mri_2D, init_traj, dc_adjoint)

Start training loop

In [149]:
num_epochs = 10
optimizer = torch.optim.RAdam(model.parameters(), lr=1e-3)
losses = []  # Store the loss values and create an animation
image_files = []  # Store the images to create a gif
model.train()

with tqdm(range(num_epochs), unit="steps") as tqdms:
    for i in tqdms:
        out = model(kspace_mri_2D)  # Forward pass

        loss = torch.nn.functional.l1_loss(out, mri_2D[None])  # Compute loss
        tqdms.set_postfix({"loss": loss.item()})  # Update progress bar
        losses.append(loss.item())  # Store loss value

        optimizer.zero_grad()  # Zero gradients
        loss.backward()  # Backward pass
        optimizer.step()  # Update weights

        # Generate a frame for the GIF
        filename = f"{tmp.NamedTemporaryFile().name}.png"
        fig, axs = plt.subplots(2, 2, figsize=(10, 10))
        plot_state(
            axs,
            mri_2D,
            init_traj,
            out,
            losses,
            save_name=filename,
        )
        plt.close(fig)  # Free the figure; the frame is already saved to disk
        image_files.append(filename)


# Make a GIF of all images.
imgs = [Image.open(img) for img in image_files]
imgs[0].save(
    "mrinufft_learn_unet.gif",
    save_all=True,
    append_images=imgs[1:],
    optimize=False,
    duration=2,
    loop=0,
)

# sphinx_gallery_thumbnail_path = 'generated/autoexamples/GPU/images/mrinufft_learn_unet.gif'
100%|██████████| 10/10 [00:08<00:00,  1.17steps/s, loss=0.625]

mrinufft_learn_unet_radial_10.gif

Reconstruction from partially trained U-Net model

In [150]:
model.eval()
new_recon = model(kspace_mri_2D)
fig, axs = plt.subplots(2, 2, figsize=(10, 10))
plot_state(axs, mri_2D, init_traj, new_recon, losses)
plt.show()
In [151]:
mri_2D_radial_10 = mri_2D.squeeze().numpy()
new_recon_2D_radial_10 = new_recon.squeeze().detach().numpy()
# As before, compare magnitudes and pass the data range explicitly.
SSIM_radial_10 = ssim(
    np.abs(new_recon_2D_radial_10),
    np.abs(mri_2D_radial_10),
    data_range=np.abs(mri_2D_radial_10).ptp(),
)
In [152]:
fig, axes = plt.subplots(1, 2, figsize=(8, 4))

# Left: U-Net reconstruction
axes[0].imshow(np.abs(new_recon_2D_radial_10), cmap="gray")
axes[0].set_title("Reconstruction 2D radial: 10 epochs")
axes[0].axis("off")

# Right: ground truth
axes[1].imshow(np.abs(mri_2D_radial_10), cmap="gray")
axes[1].set_title(f"mri_2D_radial: SSIM = {SSIM_radial_10:.3f}")
axes[1].axis("off")

plt.tight_layout()
plt.show()

References¶

.. [fastmri] O. Ronneberger, P. Fischer, and T. Brox. "U-Net: Convolutional Networks for Biomedical Image Segmentation." In *Medical Image Computing and Computer-Assisted Intervention (MICCAI)*, pages 234–241. Springer, 2015. U-Net implementation from the fastMRI repository: https://github.com/facebookresearch/fastMRI/blob/main/fastmri/models/unet.py